Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Segment Anything 2 #8243

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from

Conversation

jeanchristopheruel
Copy link

@jeanchristopheruel jeanchristopheruel commented Jul 31, 2024

Motivation and context

Regarding #8230 and #8231, I added support for the Segment Anything 2.0 as a Nuclio serverless function. The original Facebook Research repository required some modifications (see pull request) to ease the integration with Nuclio.

Note [EDITED]: This is GPU and CPU.

EDIT: Additional efforts are required to enhance the annotation experience, making it faster by decoding the embeddings client-side with onnxruntime-web. See this comment.

How has this been tested?

The changes were tested on a machine with a GPU and CUDA installed. I verified that the Nuclio function deployed correctly and was able to perform segmentation tasks using Segment Anything 2.0. The integration was tested by running various segmentation tasks and ensuring the expected output was generated. Additionally, the function's performance was monitored to ensure it operated efficiently within the Nuclio environment.

Checklist

  • I submit my changes into the develop branch
  • I have created a changelog fragment
  • I have updated the documentation accordingly
  • [-] I have added tests to cover my changes
  • [-] I have linked related issues (see GitHub docs)
  • [-] I have increased versions of npm packages if it is necessary
    (cvat-canvas,
    cvat-core,
    cvat-data and
    cvat-ui)

License

  • I submit my code changes under the same MIT License that covers the project.
    Feel free to contact the maintainers if that's a concern.

Segment Anything 2.0 require to compile a .cu file with nvcc at build time. Hence, a cuda devel baseImage is required to build the nuclio container.
Copy link
Contributor

coderabbitai bot commented Jul 31, 2024

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

The recent update enhances the documentation and functionality of a serverless image segmentation service using the Segment Anything 2.0 model. Key changes include the addition of a new entry in the README.md, the introduction of configuration and processing scripts for serverless deployment, and improvements for GPU optimization. Overall, these changes streamline the integration of advanced segmentation capabilities, making it more accessible for developers.

Changes

Files and Folders Change Summary
README.md Added a new entry for "Segment Anything 2.0" algorithm including type, framework, and support information.
changelog.d/20240731_000641_ruelj2.md Introduced features for integrating the Segment Anything 2.0 model as a Nuclio serverless function optimized for GPU.
serverless/pytorch/facebookresearch/sam2/nuclio/function-gpu.yaml Created YAML config for a Nuclio function specifying metadata, runtime, handler, triggers, and resource limits.
serverless/pytorch/facebookresearch/sam2/nuclio/main.py Implemented the serverless function with init_context and handler for image processing and mask generation.
serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py Added ModelHandler class for structured image prediction handling using the SAM2 model, optimizing inference processes.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant HTTPTrigger
    participant MainHandler
    participant ModelHandler

    User->>HTTPTrigger: Send image and points
    HTTPTrigger->>MainHandler: Forward request
    MainHandler->>MainHandler: Initialize context
    MainHandler->>ModelHandler: Process image with points
    ModelHandler->>ModelHandler: Generate mask
    ModelHandler-->>MainHandler: Return mask
    MainHandler-->>HTTPTrigger: Send response with mask
    HTTPTrigger-->>User: Display result
Loading

🐰 In fields so bright and green,
A new feature’s been seen!
With masks and models, oh so grand,
Image magic at your hand.
Hop along, let’s code away,
Segment anything, hip-hip-hooray! 🎉


Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

Share
Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai generate interesting stats about this repository and render them as a table.
    • @coderabbitai show all the console.log statements in this repository.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (invoked as PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Additionally, you can add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Outside diff range, codebase verification and nitpick comments (2)
changelog.d/20240731_000641_ruelj2.md (1)

3-3: Improve conciseness by removing redundant adverbs.

The phrase "currently fully" can be simplified for clarity.

- This fork currently fully supports Segment Anything 2.0 on GPU with the large model.
+ This fork supports Segment Anything 2.0 on GPU with the large model.
Tools
LanguageTool

[style] ~3-~3: Using two adverbs in a row like this may make your sentence unnecessarily wordy. For clarity and conciseness, see if you can omit one of them.
Context: ... the integration with Nuclio. This fork currently fully supports Segment Anything 2.0 on GPU wi...

(LY_LY)

serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py (1)

10-15: Ensure GPU availability.

The initialization of SAM2ImagePredictor assumes GPU availability without checking if CUDA is available. This can lead to issues if CUDA is not available. Consider modifying the initialization to handle cases where CUDA is not available.

  • serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py, line 15: self.predictor = SAM2ImagePredictor(build_sam2(self.model_cfg, self.sam_checkpoint, device="cuda"))
Analysis chain

Ensure GPU availability.

The initialization assumes GPU availability. Consider handling cases where CUDA is not available.

Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify if `torch.cuda.is_available()` is always true in the codebase.

# Test: Search for the usage of `torch.cuda.is_available()`. Expect: Proper handling of cases where CUDA is not available.
rg --type python -A 5 $'torch.cuda.is_available()'

Length of output: 82


Script:

#!/bin/bash
# Description: Verify if `torch.cuda.is_available()` is always true in the codebase.

# Test: Search for the usage of `torch.cuda.is_available()`. Expect: Proper handling of cases where CUDA is not available.
rg -A 5 'torch.cuda.is_available()'

Length of output: 4749

Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 2191ffa and d482fd9.

Files selected for processing (5)
  • README.md (1 hunks)
  • changelog.d/20240731_000641_ruelj2.md (1 hunks)
  • serverless/pytorch/facebookresearch/sam2/nuclio/function-gpu.yaml (1 hunks)
  • serverless/pytorch/facebookresearch/sam2/nuclio/main.py (1 hunks)
  • serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py (1 hunks)
Files skipped from review due to trivial changes (1)
  • README.md
Additional context used
LanguageTool
changelog.d/20240731_000641_ruelj2.md

[style] ~3-~3: Using two adverbs in a row like this may make your sentence unnecessarily wordy. For clarity and conciseness, see if you can omit one of them.
Context: ... the integration with Nuclio. This fork currently fully supports Segment Anything 2.0 on GPU wi...

(LY_LY)

Additional comments not posted (7)
serverless/pytorch/facebookresearch/sam2/nuclio/main.py (1)

16-19: Verify compatibility with Ampere GPUs.

Ensure that the device properties check for Ampere GPUs is correct and covers all necessary conditions.

Verification successful

The compatibility check for Ampere GPUs is correctly implemented.

The code correctly checks if the major version of the CUDA device properties is 8 or higher to identify Ampere GPUs and enables TensorFloat-32 (TF32) support accordingly.

  • serverless/pytorch/facebookresearch/sam2/nuclio/main.py:
    • Line 16: if torch.cuda.get_device_properties(0).major >= 8:
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the compatibility check for Ampere GPUs.

# Test: Search for the usage of `torch.cuda.get_device_properties`. Expect: Proper handling of Ampere GPUs.
rg --type python -A 5 $'torch.cuda.get_device_properties'

Length of output: 87


Script:

#!/bin/bash
# Description: Verify the compatibility check for Ampere GPUs.

# Test: Search for the usage of `torch.cuda.get_device_properties`. Expect: Proper handling of Ampere GPUs.
rg -A 5 'torch.cuda.get_device_properties'

Length of output: 692

serverless/pytorch/facebookresearch/sam2/nuclio/function-gpu.yaml (6)

5-17: Ensure the min_pos_points and min_neg_points values are correct.

The metadata section defines the minimum number of positive and negative points required. Verify that these values align with the requirements of the Segment Anything 2.0 model.


18-22: Ensure the runtime and handler are correctly defined.

The runtime is set to Python 3.8, and the handler is defined as main:handler. Verify that these values are correct and compatible with the Segment Anything 2.0 model.


24-50: Ensure the base image and build directives are correct.

The base image is set to pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel. Verify that this image is compatible with the Segment Anything 2.0 model and the CUDA version required.

Check the installation of dependencies and weights.

The build directives include installing dependencies and downloading weights. Ensure that these steps are correctly defined and necessary for the Segment Anything 2.0 model.


51-57: Ensure the HTTP trigger attributes are correct.

The HTTP trigger includes attributes like maxWorkers and maxRequestBodySize. Verify that these values are appropriate for the expected workload and data size.


58-60: Ensure the GPU resource limits are correct.

The resource limits specify using one GPU. Verify that this is sufficient for the Segment Anything 2.0 model's requirements.


62-67: Ensure the platform attributes are correct.

The platform section includes a restart policy and mount mode. Verify that these values are appropriate for the Nuclio function's deployment environment.

Comment on lines 17 to 28
def handle(self, image, pos_points, neg_points):
pos_points, neg_points = list(pos_points), list(neg_points)
with torch.inference_mode():
self.predictor.set_image(np.array(image))
masks, scores, logits = self.predictor.predict(
point_coords=np.array(pos_points + neg_points),
point_labels=np.array([1]*len(pos_points) + [0]*len(neg_points)),
multimask_output=True,
)
sorted_ind = np.argsort(scores)[::-1]
best_mask = masks[sorted_ind][0]
return best_mask
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure proper error handling.

The handle method does not include error handling. Consider adding try-except blocks to handle potential errors during prediction.

-    def handle(self, image, pos_points, neg_points):
-        pos_points, neg_points = list(pos_points), list(neg_points)
-        with torch.inference_mode():
-            self.predictor.set_image(np.array(image))
-            masks, scores, logits = self.predictor.predict(
-                point_coords=np.array(pos_points + neg_points),
-                point_labels=np.array([1]*len(pos_points) + [0]*len(neg_points)),
-                multimask_output=True,
-            )
-            sorted_ind = np.argsort(scores)[::-1]
-            best_mask = masks[sorted_ind][0]
-            return best_mask
+    def handle(self, image, pos_points, neg_points):
+        try:
+            pos_points, neg_points = list(pos_points), list(neg_points)
+            with torch.inference_mode():
+                self.predictor.set_image(np.array(image))
+                masks, scores, logits = self.predictor.predict(
+                    point_coords=np.array(pos_points + neg_points),
+                    point_labels=np.array([1]*len(pos_points) + [0]*len(neg_points)),
+                    multimask_output=True,
+                )
+                sorted_ind = np.argsort(scores)[::-1]
+                best_mask = masks[sorted_ind][0]
+                return best_mask
+        except Exception as e:
+            # Handle or log the error as needed
+            raise RuntimeError("Error during prediction") from e
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def handle(self, image, pos_points, neg_points):
pos_points, neg_points = list(pos_points), list(neg_points)
with torch.inference_mode():
self.predictor.set_image(np.array(image))
masks, scores, logits = self.predictor.predict(
point_coords=np.array(pos_points + neg_points),
point_labels=np.array([1]*len(pos_points) + [0]*len(neg_points)),
multimask_output=True,
)
sorted_ind = np.argsort(scores)[::-1]
best_mask = masks[sorted_ind][0]
return best_mask
def handle(self, image, pos_points, neg_points):
try:
pos_points, neg_points = list(pos_points), list(neg_points)
with torch.inference_mode():
self.predictor.set_image(np.array(image))
masks, scores, logits = self.predictor.predict(
point_coords=np.array(pos_points + neg_points),
point_labels=np.array([1]*len(pos_points) + [0]*len(neg_points)),
multimask_output=True,
)
sorted_ind = np.argsort(scores)[::-1]
best_mask = masks[sorted_ind][0]
return best_mask
except Exception as e:
# Handle or log the error as needed
raise RuntimeError("Error during prediction") from e

Comment on lines 12 to 15
def init_context(context):
# use bfloat16 for the entire notebook
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure proper cleanup of autocast context.

The torch.autocast context is entered but never exited. Ensure proper cleanup to avoid potential issues.

-    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+    context.user_data.autocast = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+    context.user_data.autocast.__enter__()
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def init_context(context):
# use bfloat16 for the entire notebook
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
def init_context(context):
# use bfloat16 for the entire notebook
context.user_data.autocast = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
context.user_data.autocast.__enter__()

Comment on lines 25 to 42
def handler(context, event):
context.logger.info("call handler")
data = event.body
buf = io.BytesIO(base64.b64decode(data["image"]))
context.logger.info(f"data: {data}")
image = Image.open(buf)
image = image.convert("RGB") # to make sure image comes in RGB
pos_points = data["pos_points"]
neg_points = data["neg_points"]

mask = context.user_data.model.handle(image, pos_points, neg_points)

return context.Response(
body=json.dumps({ 'mask': mask.tolist() }),
headers={},
content_type='application/json',
status_code=200
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure proper error handling in handler function.

The handler function does not include error handling. Consider adding try-except blocks to handle potential errors during processing.

-    context.logger.info("call handler")
-    data = event.body
-    buf = io.BytesIO(base64.b64decode(data["image"]))
-    context.logger.info(f"data: {data}")
-    image = Image.open(buf)
-    image = image.convert("RGB")  # to make sure image comes in RGB
-    pos_points = data["pos_points"]
-    neg_points = data["neg_points"]
-
-    mask = context.user_data.model.handle(image, pos_points, neg_points)
-
-    return context.Response(
-        body=json.dumps({ 'mask': mask.tolist() }),
-        headers={},
-        content_type='application/json',
-        status_code=200
-    )
+    try:
+        context.logger.info("call handler")
+        data = event.body
+        buf = io.BytesIO(base64.b64decode(data["image"]))
+        context.logger.info(f"data: {data}")
+        image = Image.open(buf)
+        image = image.convert("RGB")  # to make sure image comes in RGB
+        pos_points = data["pos_points"]
+        neg_points = data["neg_points"]
+
+        mask = context.user_data.model.handle(image, pos_points, neg_points)
+
+        return context.Response(
+            body=json.dumps({ 'mask': mask.tolist() }),
+            headers={},
+            content_type='application/json',
+            status_code=200
+        )
+    except Exception as e:
+        context.logger.error(f"Error processing request: {e}")
+        return context.Response(
+            body=json.dumps({ 'error': str(e) }),
+            headers={},
+            content_type='application/json',
+            status_code=500
+        )
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def handler(context, event):
context.logger.info("call handler")
data = event.body
buf = io.BytesIO(base64.b64decode(data["image"]))
context.logger.info(f"data: {data}")
image = Image.open(buf)
image = image.convert("RGB") # to make sure image comes in RGB
pos_points = data["pos_points"]
neg_points = data["neg_points"]
mask = context.user_data.model.handle(image, pos_points, neg_points)
return context.Response(
body=json.dumps({ 'mask': mask.tolist() }),
headers={},
content_type='application/json',
status_code=200
)
def handler(context, event):
try:
context.logger.info("call handler")
data = event.body
buf = io.BytesIO(base64.b64decode(data["image"]))
context.logger.info(f"data: {data}")
image = Image.open(buf)
image = image.convert("RGB") # to make sure image comes in RGB
pos_points = data["pos_points"]
neg_points = data["neg_points"]
mask = context.user_data.model.handle(image, pos_points, neg_points)
return context.Response(
body=json.dumps({ 'mask': mask.tolist() }),
headers={},
content_type='application/json',
status_code=200
)
except Exception as e:
context.logger.error(f"Error processing request: {e}")
return context.Response(
body=json.dumps({ 'error': str(e) }),
headers={},
content_type='application/json',
status_code=500
)

@HanClinto
Copy link
Contributor

This looks really good!

What would it take to also integrate the tracking capabilities of SAM 2?

@jeanchristopheruel
Copy link
Author

jeanchristopheruel commented Aug 3, 2024

@HanClinto I think the best approach would be to return the SAM2 memory bank queue to the user or a DB. This way, we could ensure the SAM2 service is stateless. At the moment I ignore the overhead of doing so, but the article states that the memory banks is composed of "spatial feature maps" and " lightweight vectors for high-level semantic information". The spatial feature maps transfer GPU -> CPU -> Network might be a bottleneck here depending on their size.

I can help achieving this.

The article is here

@KTXKIKI
Copy link

KTXKIKI commented Aug 4, 2024

我认为最好的方法是将 SAM2 内存组队列返回给用户或数据库。这样,我们就可以确保 SAM2 服务是无状态的。目前,我忽略了这样做的开销,但文章指出,内存库由“空间特征图”和“用于高级语义信息的轻量级向量”组成。空间特征图传输 GPU -> CPU -> 网络可能是这里的瓶颈,具体取决于它们的大小。

我可以帮助实现这一目标。

文章在这里

I tried but didn't use this, which resulted in every click requiring a resend of the request to reason
#6019

@jeanchristopheruel

@jeanchristopheruel
Copy link
Author

jeanchristopheruel commented Aug 4, 2024

@KTXKIKI you're right. There is currently a big overhead related to the request being sent for each click. I underestimated the request bottleneck, especially for large images. I thought it could be viable, given that SAM2 inference is faster than SAM1.

I'll suggest an improvement tonight.

@jeanchristopheruel
Copy link
Author

jeanchristopheruel commented Aug 4, 2024

@KTXKIKI I wont be able to produce the solution tonight. It would require to write a new cvat_ui plugin to decode the SAM2 embeddings client-side using onnxruntime-web, just like it has been done for SAM1 (cvat-ui/plugins/sam/src/ts/index.tsx). It would also require to export the SAM2 decoder in onnx format.

This thread is an excellent starting point: https://github.com/facebookresearch/segment-anything-2/issues/3

@KTXKIKI
Copy link

KTXKIKI commented Aug 4, 2024

我今晚无法提供解决方案。它需要编写一个新的 cvat_ui 插件来使用 onnxruntime-web 在客户端解码 SAM2 嵌入,就像对 SAM1 所做的那样 (cvat-ui/plugins/sam/src/ts/index.tsx)。它还需要以 onnx 格式导出 SAM2 解码器。

这个线程是一个很好的起点:facebookresearch/segment-anything-2#3

I think we need the help of official CVAT personnel

@ozangungor12
Copy link

Hi @jeanchristopheruel, thanks for your great work. As far as I understood, you added SAM2 as an interactor tool, which is working the same way as SAM does. However, the biggest improvement of SAM2 is the video tracking. Even if we somehow implement SAM2 as a tracker, CVAT UI would require us to manually go to the next frame one by one. But SAM2 video tracker is capable of tracking the object over the whole video after the first frame and points are selected. Do you know if it's possible to merge that functionality with CVAT? Is that supported somewhere in the UI at all?

@jeanchristopheruel
Copy link
Author

jeanchristopheruel commented Aug 6, 2024

@ozangungor12 It is possible to integrate SAM2 for video tracking with its featured memory embeddings. It would require to write a new cvat_ui plugin to decode the SAM2 embeddings (encoded featuremap & memory bank) client-side using onnxruntime-web. This would allow the full serverless compatibility with nuclio and ensure scalability (stateless) for cloud applications.

Alternatively, for your own interest, you can modify the SAM2 model_handler.py to maintain the state of the last processed image and you can add a REST endpoint to clear the state on demand. This alternative is NOT clean and should be used only within a single session.

@realtimshady1
Copy link

Hi, this looks great. I noticed that there isn't a function.yaml for serverless without gpu. Any reason for that?

@ozangungor12
Copy link

Hi, this looks great. I noticed that there isn't a function.yaml for serverless without gpu. Any reason for that?

I don't think SAM2 can work without a GPU.

@jeanchristopheruel also said in the PR:

Note: Segment Anything 2.0 require to compile a .cu file with nvcc at build time. Hence, a cuda devel baseImage is required to build the nuclio container and no support available for CPU. This is GPU only.

@bhack
Copy link

bhack commented Aug 9, 2024

I don't think SAM2 can work without a GPU.

See the thread at https://github.com/facebookresearch/segment-anything-2/pull/155

@ozangungor12
Copy link

I don't think SAM2 can work without a GPU.

See the thread at facebookresearch/segment-anything-2#155

Great, thanks for sending it!

@jeanchristopheruel
Copy link
Author

I don't think SAM2 can work without a GPU.

See the thread at facebookresearch/segment-anything-2#155

@ozangungor12, @bhack and @realtimshady1, I added support for cpu based on this. Thanks for the info.

@nmanovic
Copy link
Contributor

@jeanchristopheruel , thank you for the PR. Could you please look at linters?

Copy link

sonarcloud bot commented Aug 15, 2024

@nmanovic
Copy link
Contributor

@jeanchristopheruel , we will be happy to merge the version of SAM2 into CVAT open-source repository. Need to say that our team implemented optimized version of SAM2: https://www.cvat.ai/post/meta-segment-anything-model-v2-is-now-available-in-cvat-ai. It will be available on SaaS for our paid customers and for Enterprise customers.

@jeanchristopheruel
Copy link
Author

@nmanovic Thanks for your response. However, I’m disappointed to see that key advancements like SAM2 are becoming restricted to paid users. CVAT has always been a strong open-source tool, and limiting such features seems to move away from that spirit. I hope you will reconsider and keep these innovations accessible to the broader open-source community.

@nmanovic
Copy link
Contributor

@jeanchristopheruel , I would make all features open-source if it were possible. However, delivering new and innovative features to the open-source repository, such as the YOLOv8 format support (#8240), and addressing security issues and bugs, requires financial backing. To sustain this level of development, we rely on the support of paying customers. The best way to help CVAT continue thriving is by purchasing a SaaS subscription (https://www.cvat.ai/pricing/cloud) or becoming an Enterprise customer (https://www.cvat.ai/pricing/on-prem).

It's worth noting that around 80% of our contributions go directly into the open-source repository.

@jeanchristopheruel
Copy link
Author

jeanchristopheruel commented Aug 16, 2024

@nmanovic, I understand the need for financial support to sustain development, and I appreciate all the work your team does. However, history has shown that moving key features behind paywalls can sometimes alienate open-source communities. For example, when Elasticsearch restricted features, it led to the community forking it into OpenSearch.

I hope CVAT can find a balance that supports both its financial needs and keeps innovation accessible to the open-source community, as that's what has made CVAT so valuable to so many. 😌

@jeanchristopheruel
Copy link
Author

jeanchristopheruel commented Aug 17, 2024

For those with stronger frontend skills, I recommend checking out this repository, which contains a complete frontend implementation of SAM2 using onnxruntime-web/webgpu.

I also attempted a frontend implementation, and you can find my initial trial here. It's still a work in progress, but feel free to take a look.

@jeanchristopheruel jeanchristopheruel changed the title Introduce Segment Anything 2.0 Introduce Segment Anything 2 Aug 17, 2024
@Youho99
Copy link

Youho99 commented Aug 20, 2024

Another possible and very useful feature with models like SAM and SAM2 would be precision annotation in bounding boxes.

The idea is to make an imprecise bounding box around the object to be annotated. The bounding box is sent to the SAM or SAM2 model, which segments the main object from the bounding box it receives. Finally, the precise bounding box is recreated by taking the extremum coordinates at the top, left, bottom, right.

This would allow very quick and precise annotating, without having to zoom in on the image (very useful for precise annotation of small objects for example).

In my free time, I made a python script using this logic with SAM to make precision annotation, taking as input an annotation json (COCO format I think) and which output a json in the same format, with the precise bounding boxes recalculated.

I could make it available to you if necessary.

@jeanchristopheruel
Copy link
Author

@Youho99 Very cool indeed! I suggest you create a separate issue to express your feature idea.🙂

@Youho99
Copy link

Youho99 commented Aug 21, 2024

@Youho99 Very cool indeed! I suggest you create a separate issue to express your feature idea.🙂

Done ✔️ #8326

@tpsvoets
Copy link

Great! Do you have an estimate when the tracking aspect / video annotation aspect will be implemented?

@jeanchristopheruel
Copy link
Author

jeanchristopheruel commented Aug 21, 2024

For those with stronger frontend skills, I recommend checking out this repository, which contains a complete frontend implementation of SAM2 using onnxruntime-web/webgpu.

I also attempted a frontend implementation, and you can find my initial trial here. It's still a work in progress, but feel free to take a look.

@tpsvoets The current PR adds support for an encoder-decoder sam2 backend, which makes the thing slower than sam1 plugin due to the request overhead. (Sam1 plugin has the decoder running in frontend).

Can't give a timeline for sam2 encoder-decoder frontend support since I am not currently working on it. Maybe in the next year..

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants